1   /*
2    * Copyright (C) 2010 The Guava Authors
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    * http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package com.google.common.collect.testing;
18  
19  import java.io.Serializable;
20  import java.util.AbstractSet;
21  import java.util.Collection;
22  import java.util.Comparator;
23  import java.util.Iterator;
24  import java.util.Map;
25  import java.util.NavigableMap;
26  import java.util.NavigableSet;
27  import java.util.Set;
28  import java.util.SortedMap;
29  import java.util.TreeMap;
30  
/**
 * A {@link NavigableMap} that forwards to a {@code TreeMap} (or to a view of
 * one) and eagerly verifies that every key handed to it can be compared by
 * the map's ordering. Verification is done by comparing the key with itself,
 * so whatever the comparison throws for an unsuitable key surfaces
 * immediately. This implementation passes the navigable map test suites.
 *
 * @author Louis Wasserman
 */
public final class SafeTreeMap<K, V>
    implements Serializable, NavigableMap<K, V> {
  private static final long serialVersionUID = 0L;

  /**
   * Ordering reported by {@link #comparator()} when the backing map has no
   * explicit comparator: the first argument is cast to {@code Comparable}
   * and asked to compare itself with the second.
   */
  private static final Comparator<Object> NATURAL_ORDER = new Comparator<Object>() {
    @Override
    public int compare(Object left, Object right) {
      // The cast fails with a ClassCastException for non-Comparable objects.
      @SuppressWarnings("unchecked")
      Comparable<Object> comparableLeft = (Comparable<Object>) left;
      return comparableLeft.compareTo(right);
    }
  };

  /** The map that actually stores the entries. */
  private final NavigableMap<K, V> delegate;

  // ------------------------------------------------------------------
  // Construction
  // ------------------------------------------------------------------

  /** Creates an empty map backed by a new {@code TreeMap}. */
  public SafeTreeMap() {
    this(new TreeMap<K, V>());
  }

  /**
   * Creates an empty map backed by a new {@code TreeMap} that uses the
   * given comparator.
   */
  public SafeTreeMap(Comparator<? super K> comparator) {
    this(new TreeMap<K, V>(comparator));
  }

  /** Creates a map backed by a new {@code TreeMap} copy of {@code map}. */
  public SafeTreeMap(Map<? extends K, ? extends V> map) {
    this(new TreeMap<K, V>(map));
  }

  /** Creates a map backed by a new {@code TreeMap} copy of {@code map}. */
  public SafeTreeMap(SortedMap<K, ? extends V> map) {
    this(new TreeMap<K, V>(map));
  }

  /**
   * Wraps {@code delegate} directly (no copy) and validates every key that
   * is already present in it.
   *
   * @throws NullPointerException if {@code delegate} is null
   */
  private SafeTreeMap(NavigableMap<K, V> delegate) {
    if (delegate == null) {
      throw new NullPointerException();
    }
    this.delegate = delegate;
    Iterator<K> existingKeys = delegate.navigableKeySet().iterator();
    while (existingKeys.hasNext()) {
      checkValid(existingKeys.next());
    }
  }

  /**
   * Compares {@code candidate} with itself using {@link #comparator()} and
   * returns it unchanged. Any exception raised by the comparison propagates
   * to the caller.
   */
  private <T> T checkValid(T candidate) {
    // a ClassCastException is what's supposed to happen!
    @SuppressWarnings("unchecked")
    K asKey = (K) candidate;
    Comparator<? super K> ordering = comparator();
    ordering.compare(asKey, asKey);
    return candidate;
  }

  // ------------------------------------------------------------------
  // Ordering
  // ------------------------------------------------------------------

  /**
   * Returns the delegate's comparator, or a natural-ordering comparator when
   * the delegate reports {@code null}; the result is never {@code null}.
   */
  @Override
  public Comparator<? super K> comparator() {
    Comparator<? super K> explicit = delegate.comparator();
    if (explicit != null) {
      return explicit;
    }
    return NATURAL_ORDER;
  }

  // ------------------------------------------------------------------
  // Size, lookup and mutation
  // ------------------------------------------------------------------

  @Override
  public int size() {
    return delegate.size();
  }

  @Override
  public boolean isEmpty() {
    return delegate.isEmpty();
  }

  /**
   * Reports whether {@code key} is present. A {@code NullPointerException}
   * or {@code ClassCastException} raised while validating or looking up the
   * key is treated as "not present".
   */
  @Override
  public boolean containsKey(Object key) {
    boolean found;
    try {
      found = delegate.containsKey(checkValid(key));
    } catch (NullPointerException e) {
      found = false;
    } catch (ClassCastException e) {
      found = false;
    }
    return found;
  }

  @Override
  public boolean containsValue(Object value) {
    return delegate.containsValue(value);
  }

  /** Validates {@code key}, then looks it up in the delegate. */
  @Override
  public V get(Object key) {
    return delegate.get(checkValid(key));
  }

  /** Validates {@code key}, then stores the mapping in the delegate. */
  @Override
  public V put(K key, V value) {
    return delegate.put(checkValid(key), value);
  }

  /**
   * Validates every key of {@code map} first, then copies all of its
   * mappings into the delegate.
   */
  @Override
  public void putAll(Map<? extends K, ? extends V> map) {
    Iterator<? extends K> incomingKeys = map.keySet().iterator();
    while (incomingKeys.hasNext()) {
      checkValid(incomingKeys.next());
    }
    delegate.putAll(map);
  }

  /** Validates {@code key}, then removes its mapping from the delegate. */
  @Override
  public V remove(Object key) {
    return delegate.remove(checkValid(key));
  }

  @Override
  public void clear() {
    delegate.clear();
  }

  // ------------------------------------------------------------------
  // Endpoints
  // ------------------------------------------------------------------

  @Override
  public K firstKey() {
    return delegate.firstKey();
  }

  @Override
  public K lastKey() {
    return delegate.lastKey();
  }

  @Override
  public Entry<K, V> firstEntry() {
    return delegate.firstEntry();
  }

  @Override
  public Entry<K, V> lastEntry() {
    return delegate.lastEntry();
  }

  @Override
  public Entry<K, V> pollFirstEntry() {
    return delegate.pollFirstEntry();
  }

  @Override
  public Entry<K, V> pollLastEntry() {
    return delegate.pollLastEntry();
  }

  // ------------------------------------------------------------------
  // Navigation; each method validates the probe key before delegating
  // ------------------------------------------------------------------

  @Override
  public Entry<K, V> lowerEntry(K key) {
    return delegate.lowerEntry(checkValid(key));
  }

  @Override
  public K lowerKey(K key) {
    return delegate.lowerKey(checkValid(key));
  }

  @Override
  public Entry<K, V> floorEntry(K key) {
    return delegate.floorEntry(checkValid(key));
  }

  @Override
  public K floorKey(K key) {
    return delegate.floorKey(checkValid(key));
  }

  @Override
  public Entry<K, V> ceilingEntry(K key) {
    return delegate.ceilingEntry(checkValid(key));
  }

  @Override
  public K ceilingKey(K key) {
    return delegate.ceilingKey(checkValid(key));
  }

  @Override
  public Entry<K, V> higherEntry(K key) {
    return delegate.higherEntry(checkValid(key));
  }

  @Override
  public K higherKey(K key) {
    return delegate.higherKey(checkValid(key));
  }

  // ------------------------------------------------------------------
  // Map views; each result is a new SafeTreeMap around the delegate's
  // view, so the view's keys are validated again on creation
  // ------------------------------------------------------------------

  @Override
  public NavigableMap<K, V> descendingMap() {
    return new SafeTreeMap<K, V>(delegate.descendingMap());
  }

  /** Equivalent to {@code headMap(toKey, false)}. */
  @Override
  public SortedMap<K, V> headMap(K toKey) {
    return headMap(toKey, false);
  }

  @Override
  public NavigableMap<K, V> headMap(K toKey, boolean inclusive) {
    NavigableMap<K, V> view = delegate.headMap(checkValid(toKey), inclusive);
    return new SafeTreeMap<K, V>(view);
  }

  /** Equivalent to {@code subMap(fromKey, true, toKey, false)}. */
  @Override
  public SortedMap<K, V> subMap(K fromKey, K toKey) {
    return subMap(fromKey, true, toKey, false);
  }

  @Override
  public NavigableMap<K, V> subMap(
      K fromKey, boolean fromInclusive, K toKey, boolean toInclusive) {
    NavigableMap<K, V> view = delegate.subMap(
        checkValid(fromKey), fromInclusive, checkValid(toKey), toInclusive);
    return new SafeTreeMap<K, V>(view);
  }

  /** Equivalent to {@code tailMap(fromKey, true)}. */
  @Override
  public SortedMap<K, V> tailMap(K fromKey) {
    return tailMap(fromKey, true);
  }

  @Override
  public NavigableMap<K, V> tailMap(K fromKey, boolean inclusive) {
    NavigableMap<K, V> view = delegate.tailMap(checkValid(fromKey), inclusive);
    return new SafeTreeMap<K, V>(view);
  }

  // ------------------------------------------------------------------
  // Collection views
  // ------------------------------------------------------------------

  /** Same as {@link #navigableKeySet()}. */
  @Override
  public NavigableSet<K> keySet() {
    return navigableKeySet();
  }

  /** Returns the delegate's navigable key set as is. */
  @Override
  public NavigableSet<K> navigableKeySet() {
    return delegate.navigableKeySet();
  }

  /** Returns the delegate's descending key set as is. */
  @Override
  public NavigableSet<K> descendingKeySet() {
    return delegate.descendingKeySet();
  }

  /** Returns the delegate's value collection as is. */
  @Override
  public Collection<V> values() {
    return delegate.values();
  }

  /**
   * Returns a fresh set view over the delegate's entry set whose
   * {@code contains} answers {@code false} instead of throwing.
   */
  @Override
  public Set<Entry<K, V>> entrySet() {
    return new EntrySetView();
  }

  /**
   * Entry-set view that forwards to the delegate's entry set, fetching it
   * anew on every operation.
   */
  private final class EntrySetView extends AbstractSet<Entry<K, V>> {
    private Set<Entry<K, V>> backing() {
      return delegate.entrySet();
    }

    /**
     * Forwards to the backing set, mapping a {@code NullPointerException}
     * or {@code ClassCastException} to {@code false}.
     */
    @Override
    public boolean contains(Object object) {
      boolean present;
      try {
        present = backing().contains(object);
      } catch (NullPointerException e) {
        present = false;
      } catch (ClassCastException e) {
        present = false;
      }
      return present;
    }

    @Override
    public Iterator<Entry<K, V>> iterator() {
      return backing().iterator();
    }

    @Override
    public int size() {
      return backing().size();
    }

    @Override
    public boolean remove(Object o) {
      return backing().remove(o);
    }

    @Override
    public void clear() {
      backing().clear();
    }
  }

  // ------------------------------------------------------------------
  // Object methods, all answered by the delegate
  // ------------------------------------------------------------------

  @Override
  public boolean equals(Object obj) {
    return delegate.equals(obj);
  }

  @Override
  public int hashCode() {
    return delegate.hashCode();
  }

  @Override
  public String toString() {
    return delegate.toString();
  }
}